{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 语义块切分 Semantic Chunking\n",
    "\n",
    "文本块切分是检索增强生成（RAG）中的关键步骤，其中将大文本体分割成有意义的段落以提高检索准确性。与固定长度块切分不同，语义块切分是根据句子之间的内容相似性来分割文本的。\n",
    "\n",
    "------\n",
    "分块断点方法：\n",
    "- **Percentile（百分位）**: 找到所有相似度差异小于设定的百分点的内容，并在超过此值的下降处分割块。\n",
    "- **Standard Deviation（标准差）**: 相似度超过 X 个标准差以下的分割点。\n",
    "- **Interquartile Range (IQR)（四分位距）**: 使用四分位距（Q3 - Q1）来确定分割点。\n",
    "\n",
    "------\n",
    "实现步骤：\n",
    "- 从PDF文件中提取文本：按句子进行切分\n",
    "- 提取的文本创建语义分块：\n",
    "    - 将前后两个相邻的句子计算相似度\n",
    "    - 根据相似度下降计算分块的断点，断点方法有三种：百分位、标准差和四分位距\n",
    "    - 然后根据断点分割文本，得到语义块\n",
    "- 创建嵌入\n",
    "- 根据查询并检索文档\n",
    "- 将检索出来的文本用于模型生成回答"
   ],
   "id": "7f73470ceb6cf395"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 设置环境",
   "id": "4471f27733aeec46"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-23T06:06:30.347847Z",
     "start_time": "2025-04-23T06:06:28.953369Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import fitz\n",
    "import os\n",
    "import numpy as np\n",
    "import json\n",
    "from openai import OpenAI\n",
    "from dotenv import load_dotenv\n",
    "# from sentence_transformers import SentenceTransformer, util\n",
    "# from typing import List\n",
    "\n",
    "load_dotenv()"
   ],
   "id": "e47aaacb8642f486",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 从 PDF 文件中提取文本",
   "id": "f349543f476dd541"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-23T06:06:32.505973Z",
     "start_time": "2025-04-23T06:06:32.449642Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def extract_text_from_pdf(pdf_path):\n",
    "    \"\"\"\n",
    "    Extracts text from a PDF file.\n",
    "\n",
    "    Args:\n",
    "    pdf_path (str): Path to the PDF file.\n",
    "\n",
    "    Returns:\n",
    "    str: Extracted text from the PDF.\n",
    "    \"\"\"\n",
    "    # Open the PDF file\n",
    "    mypdf = fitz.open(pdf_path)\n",
    "    all_text = \"\"  # Initialize an empty string to store the extracted text\n",
    "\n",
    "    # Iterate through each page in the PDF\n",
    "    for page in mypdf:\n",
    "        # Extract text from the current page and add spacing\n",
    "        all_text += page.get_text(\"text\") + \" \"\n",
    "\n",
    "    # Return the extracted text, stripped of leading/trailing whitespace\n",
    "    return all_text.strip()\n",
    "\n",
    "# Define the path to the PDF file\n",
    "pdf_path = \"../../data/AI_Information.en.zh-CN.pdf\"\n",
    "\n",
    "# Extract text from the PDF file\n",
    "extracted_text = extract_text_from_pdf(pdf_path)\n",
    "\n",
    "# Print the first 500 characters of the extracted text\n",
    "print(extracted_text[:500])"
   ],
   "id": "b5c7fe6527773893",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "理解⼈⼯智能\n",
      "第⼀章：⼈⼯智能简介\n",
      "⼈⼯智能 (AI) 是指数字计算机或计算机控制的机器⼈执⾏通常与智能⽣物相关的任务的能⼒。该术\n",
      "语通常⽤于开发具有⼈类特有的智⼒过程的系统，例如推理、发现意义、概括或从过往经验中学习\n",
      "的能⼒。在过去的⼏⼗年中，计算能⼒和数据可⽤性的进步显著加速了⼈⼯智能的开发和部署。\n",
      "历史背景\n",
      "⼈⼯智能的概念已存在数个世纪，经常出现在神话和⼩说中。然⽽，⼈⼯智能研究的正式领域始于\n",
      "20世纪中叶。1956年的达特茅斯研讨会被⼴泛认为是⼈⼯智能的发源地。早期的⼈⼯智能研究侧\n",
      "重于问题解决和符号⽅法。20世纪80年代专家系统兴起，⽽20世纪90年代和21世纪初，机器学习\n",
      "和神经⽹络取得了进步。深度学习的最新突破彻底改变了这⼀领域。\n",
      "现代观察\n",
      "现代⼈⼯智能系统在⽇常⽣活中⽇益普及。从 Siri 和 Alexa 等虚拟助⼿，到流媒体服务和社交媒体\n",
      "上的推荐算法，⼈⼯智能正在影响我们的⽣活、⼯作和互动⽅式。⾃动驾驶汽⻋、先进的医疗诊断\n",
      "技术以及复杂的⾦融建模⼯具的发展，彰显了⼈⼯智能应⽤的⼴泛性和持续增⻓。此外，⼈们对其\n",
      "伦理影响、偏⻅和失业的担忧也⽇益凸显。\n",
      "第⼆章：⼈⼯智能\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 设置 OpenAI API 客户端\n",
    "\n",
    "初始化 OpenAI 客户端以生成嵌入和响应"
   ],
   "id": "f06262175de201bc"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-23T06:06:36.218029Z",
     "start_time": "2025-04-23T06:06:35.833946Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 国内支持类OpenAI的API都可，需要配置对应的base_url和api_key\n",
    "\n",
    "client = OpenAI(\n",
    "    base_url=os.getenv(\"LLM_BASE_URL\"),\n",
    "    api_key=os.getenv(\"LLM_API_KEY\")\n",
    ")"
   ],
   "id": "6808757a46532a2c",
   "outputs": [],
   "execution_count": 3
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 创建句子级嵌入(sentence-level embedding)\n",
    "\n",
    "将文本分割成句子并生成嵌入"
   ],
   "id": "4395af2c1e1a4774"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-23T06:06:58.867715Z",
     "start_time": "2025-04-23T06:06:38.122129Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def get_embedding(text):\n",
    "    response = client.embeddings.create(\n",
    "        model=os.getenv(\"EMBEDDING_MODEL_ID\"),\n",
    "        input=text\n",
    "    )\n",
    "    return np.array(response.data[0].embedding)\n",
    "\n",
    "# Splitting text into sentences (basic split)\n",
    "sentences = extracted_text.split(\"。\")\n",
    "print(len(sentences))\n",
    "# Generate embeddings for each sentence\n",
    "# FIXME: 因为PDF切分的最后一个句子为空，需要过滤掉\n",
    "embeddings = [get_embedding(sentence) for sentence in sentences if sentence]\n",
    "print(f\"Generated {len(embeddings)} sentence embeddings.\")\n",
    "\n",
    "# def get_embedding(text, model_path: str = \"../rag_naive/model/gte-base-zh\"):\n",
    "#     \"\"\"\n",
    "#     Creates an embedding for the given text using OpenAI.\n",
    "#\n",
    "#     Args:\n",
    "#     text (str): Input text.\n",
    "#     model (str): Embedding model name.\n",
    "#\n",
    "#     Returns:\n",
    "#     np.ndarray: The embedding vector.\n",
    "#     \"\"\"\n",
    "#     st_model = SentenceTransformer(model_name_or_path=model_path)\n",
    "#     st_embeddings = st_model.encode(text, normalize_embeddings=True)\n",
    "#     response = [embedding.tolist() for embedding in st_embeddings]\n",
    "#\n",
    "#     return np.array(response)\n",
    "#\n",
    "# # Splitting text into sentences (basic split)\n",
    "# sentences = extracted_text.split(\"。\")\n",
    "# print(len(sentences))\n",
    "#\n",
    "# # Generate embeddings for each sentence\n",
    "# embeddings = [get_embedding(sentence) for sentence in sentences]\n",
    "#\n",
    "# print(f\"Generated {len(embeddings)} sentence embeddings.\")"
   ],
   "id": "331379b35781bb4d",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "257\n",
      "Generated 256 sentence embeddings.\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 计算相似度差异\n",
    "\n",
    "计算连续句子的余弦相似度"
   ],
   "id": "34426aed5796c647"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-23T06:07:00.685858Z",
     "start_time": "2025-04-23T06:07:00.676123Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def cosine_similarity(vec1, vec2):\n",
    "    \"\"\"\n",
    "    Computes cosine similarity between two vectors.\n",
    "\n",
    "    Args:\n",
    "    vec1 (np.ndarray): First vector.\n",
    "    vec2 (np.ndarray): Second vector.\n",
    "\n",
    "    Returns:\n",
    "    float: Cosine similarity.\n",
    "    \"\"\"\n",
    "    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))\n",
    "\n",
    "# Compute similarity between consecutive sentences\n",
    "similarities = [cosine_similarity(embeddings[i], embeddings[i + 1]) for i in range(len(embeddings) - 1)]"
   ],
   "id": "dea93adcd088d5ba",
   "outputs": [],
   "execution_count": 5
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 实现语义分块 Semantic Chunking\n",
    "\n",
    "实现了三种不同的方法来查找断点"
   ],
   "id": "12d67649221a37b5"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-23T06:07:02.746228Z",
     "start_time": "2025-04-23T06:07:02.729795Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def compute_breakpoints(similarities, method=\"percentile\", threshold=90):\n",
    "    \"\"\"\n",
    "    根据相似度下降计算分块的断点。\n",
    "\n",
    "    Args:\n",
    "        similarities (List[float]): 句子之间的相似度分数列表。\n",
    "        method (str): 'percentile'（百分位）、'standard_deviation'（标准差）或 'interquartile'（四分位距）。\n",
    "        threshold (float): 阈值（对于 'percentile' 是百分位数，对于 'standard_deviation' 是标准差倍数）。\n",
    "\n",
    "    Returns:\n",
    "        List[int]: 分块的索引列表。\n",
    "    \"\"\"\n",
    "    # 根据选定的方法确定阈值\n",
    "    if method == \"percentile\":\n",
    "        # 计算相似度分数的第 X 百分位数\n",
    "        threshold_value = np.percentile(similarities, threshold)\n",
    "    elif method == \"standard_deviation\":\n",
    "        # 计算相似度分数的均值和标准差。\n",
    "        mean = np.mean(similarities)\n",
    "        std_dev = np.std(similarities)\n",
    "        # 将阈值设置为均值减去 X 倍的标准差\n",
    "        threshold_value = mean - (threshold * std_dev)\n",
    "    elif method == \"interquartile\":\n",
    "        # 计算第一和第三四分位数（Q1 和 Q3）。\n",
    "        q1, q3 = np.percentile(similarities, [25, 75])\n",
    "        # 使用 IQR 规则（四分位距规则）设置阈值\n",
    "        threshold_value = q1 - 1.5 * (q3 - q1)\n",
    "    else:\n",
    "        # 如果提供了无效的方法，则抛出异常\n",
    "        raise ValueError(\"Invalid method. Choose 'percentile', 'standard_deviation', or 'interquartile'.\")\n",
    "\n",
    "    # 找出相似度低于阈值的索引\n",
    "    return [i for i, sim in enumerate(similarities) if sim < threshold_value]\n",
    "\n",
    "# 使用百分位法计算断点，阈值为90\n",
    "breakpoints = compute_breakpoints(similarities, method=\"percentile\", threshold=90)\n",
    "breakpoints"
   ],
   "id": "2cb8614c65e7fcc",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 4,\n",
       " 5,\n",
       " 6,\n",
       " 7,\n",
       " 8,\n",
       " 9,\n",
       " 10,\n",
       " 11,\n",
       " 12,\n",
       " 13,\n",
       " 14,\n",
       " 15,\n",
       " 16,\n",
       " 17,\n",
       " 18,\n",
       " 19,\n",
       " 20,\n",
       " 21,\n",
       " 22,\n",
       " 23,\n",
       " 24,\n",
       " 25,\n",
       " 26,\n",
       " 27,\n",
       " 28,\n",
       " 29,\n",
       " 30,\n",
       " 32,\n",
       " 33,\n",
       " 34,\n",
       " 35,\n",
       " 36,\n",
       " 37,\n",
       " 38,\n",
       " 39,\n",
       " 40,\n",
       " 41,\n",
       " 42,\n",
       " 44,\n",
       " 45,\n",
       " 46,\n",
       " 48,\n",
       " 49,\n",
       " 50,\n",
       " 52,\n",
       " 53,\n",
       " 54,\n",
       " 55,\n",
       " 56,\n",
       " 57,\n",
       " 58,\n",
       " 59,\n",
       " 60,\n",
       " 61,\n",
       " 62,\n",
       " 63,\n",
       " 65,\n",
       " 66,\n",
       " 67,\n",
       " 68,\n",
       " 69,\n",
       " 70,\n",
       " 71,\n",
       " 72,\n",
       " 73,\n",
       " 74,\n",
       " 76,\n",
       " 77,\n",
       " 78,\n",
       " 79,\n",
       " 81,\n",
       " 82,\n",
       " 83,\n",
       " 84,\n",
       " 85,\n",
       " 86,\n",
       " 87,\n",
       " 88,\n",
       " 90,\n",
       " 91,\n",
       " 92,\n",
       " 93,\n",
       " 94,\n",
       " 95,\n",
       " 96,\n",
       " 97,\n",
       " 98,\n",
       " 99,\n",
       " 100,\n",
       " 101,\n",
       " 102,\n",
       " 103,\n",
       " 105,\n",
       " 106,\n",
       " 107,\n",
       " 109,\n",
       " 110,\n",
       " 111,\n",
       " 112,\n",
       " 113,\n",
       " 114,\n",
       " 115,\n",
       " 116,\n",
       " 117,\n",
       " 118,\n",
       " 119,\n",
       " 120,\n",
       " 121,\n",
       " 122,\n",
       " 124,\n",
       " 125,\n",
       " 127,\n",
       " 129,\n",
       " 130,\n",
       " 131,\n",
       " 132,\n",
       " 133,\n",
       " 134,\n",
       " 135,\n",
       " 136,\n",
       " 137,\n",
       " 138,\n",
       " 139,\n",
       " 140,\n",
       " 141,\n",
       " 142,\n",
       " 143,\n",
       " 144,\n",
       " 145,\n",
       " 147,\n",
       " 148,\n",
       " 149,\n",
       " 150,\n",
       " 151,\n",
       " 152,\n",
       " 154,\n",
       " 155,\n",
       " 157,\n",
       " 159,\n",
       " 161,\n",
       " 162,\n",
       " 163,\n",
       " 164,\n",
       " 165,\n",
       " 166,\n",
       " 167,\n",
       " 168,\n",
       " 169,\n",
       " 170,\n",
       " 171,\n",
       " 172,\n",
       " 173,\n",
       " 174,\n",
       " 175,\n",
       " 176,\n",
       " 177,\n",
       " 179,\n",
       " 180,\n",
       " 181,\n",
       " 182,\n",
       " 183,\n",
       " 184,\n",
       " 185,\n",
       " 186,\n",
       " 187,\n",
       " 188,\n",
       " 189,\n",
       " 190,\n",
       " 191,\n",
       " 192,\n",
       " 193,\n",
       " 194,\n",
       " 195,\n",
       " 197,\n",
       " 199,\n",
       " 200,\n",
       " 201,\n",
       " 202,\n",
       " 203,\n",
       " 204,\n",
       " 205,\n",
       " 206,\n",
       " 207,\n",
       " 208,\n",
       " 209,\n",
       " 211,\n",
       " 212,\n",
       " 213,\n",
       " 214,\n",
       " 215,\n",
       " 216,\n",
       " 217,\n",
       " 218,\n",
       " 219,\n",
       " 220,\n",
       " 221,\n",
       " 223,\n",
       " 224,\n",
       " 225,\n",
       " 226,\n",
       " 227,\n",
       " 228,\n",
       " 229,\n",
       " 230,\n",
       " 231,\n",
       " 232,\n",
       " 233,\n",
       " 234,\n",
       " 236,\n",
       " 237,\n",
       " 238,\n",
       " 239,\n",
       " 240,\n",
       " 241,\n",
       " 243,\n",
       " 244,\n",
       " 245,\n",
       " 246,\n",
       " 247,\n",
       " 248,\n",
       " 249,\n",
       " 250,\n",
       " 251,\n",
       " 252,\n",
       " 253]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 6
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 将文本分割成语义块\n",
    "\n",
    "将文本基于断点进行分割"
   ],
   "id": "d3332d4ad7b800a7"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-23T06:07:12.222705Z",
     "start_time": "2025-04-23T06:07:12.215242Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def split_into_chunks(sentences, breakpoints):\n",
    "    \"\"\"\n",
    "    将句子分割为语义块\n",
    "\n",
    "    Args:\n",
    "    sentences (List[str]): 句子列表\n",
    "    breakpoints (List[int]): 进行分块的索引位置\n",
    "\n",
    "    Returns:\n",
    "    List[str]: 文本块列表\n",
    "    \"\"\"\n",
    "    chunks = []  # Initialize an empty list to store the chunks\n",
    "    start = 0  # Initialize the start index\n",
    "\n",
    "    # 遍历每个断点以创建块\n",
    "    for bp in breakpoints:\n",
    "        # 将从起始位置到当前断点的句子块追加到列表中\n",
    "        chunks.append(\"。\".join(sentences[start:bp + 1]) + \"。\")\n",
    "        start = bp + 1  # 将起始索引更新为断点后的下一个句子\n",
    "\n",
    "    # 将剩余的句子作为最后一个块追加\n",
    "    chunks.append(\"。\".join(sentences[start:]))\n",
    "    return chunks  # Return the list of chunks\n",
    "\n",
    "# split_into_chunks 函数创建文本块\n",
    "text_chunks = split_into_chunks(sentences, breakpoints)\n",
    "\n",
    "# Print the number of chunks created\n",
    "print(f\"Number of semantic chunks: {len(text_chunks)}\")\n",
    "\n",
    "# Print the first chunk to verify the result\n",
    "print(\"\\nFirst text chunk:\")\n",
    "print(text_chunks[0])"
   ],
   "id": "453a4c0a446643b6",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of semantic chunks: 230\n",
      "\n",
      "First text chunk:\n",
      "理解⼈⼯智能\n",
      "第⼀章：⼈⼯智能简介\n",
      "⼈⼯智能 (AI) 是指数字计算机或计算机控制的机器⼈执⾏通常与智能⽣物相关的任务的能⼒。\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 语义块创建嵌入\n",
    "\n",
    "为每个片段创建嵌入，以便后续检索"
   ],
   "id": "b28133ccf6bbc698"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-23T06:07:36.513429Z",
     "start_time": "2025-04-23T06:07:15.971007Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def create_embeddings(text_chunks):\n",
    "    \"\"\"\n",
    "    Creates embeddings for each text chunk.\n",
    "\n",
    "    Args:\n",
    "    text_chunks (List[str]): List of text chunks.\n",
    "\n",
    "    Returns:\n",
    "    List[np.ndarray]: List of embedding vectors.\n",
    "    \"\"\"\n",
    "    # Generate embeddings for each text chunk using the get_embedding function\n",
    "    return [get_embedding(chunk) for chunk in text_chunks]\n",
    "\n",
    "# Create chunk embeddings using the create_embeddings function\n",
    "chunk_embeddings = create_embeddings(text_chunks)"
   ],
   "id": "bea22ece9a451670",
   "outputs": [],
   "execution_count": 8
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 语义搜索\n",
    "\n",
    "余弦相似度来检索最相关的片段"
   ],
   "id": "89aac3ef583e2564"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-23T06:08:10.846693Z",
     "start_time": "2025-04-23T06:08:10.842133Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def semantic_search(query, text_chunks, chunk_embeddings, k=5):\n",
    "    \"\"\"\n",
    "    查询找到最相关的文本块\n",
    "\n",
    "    Args:\n",
    "    query (str): Search query.\n",
    "    text_chunks (List[str]): List of text chunks.\n",
    "    chunk_embeddings (List[np.ndarray]): List of chunk embeddings.\n",
    "    k (int): Number of top results to return.\n",
    "\n",
    "    Returns:\n",
    "    List[str]: Top-k relevant chunks.\n",
    "    \"\"\"\n",
    "    # 为查询生成嵌入\n",
    "    query_embedding = get_embedding(query)\n",
    "\n",
    "    # 计算查询嵌入与每个块嵌入之间的余弦相似度\n",
    "    similarities = [cosine_similarity(query_embedding, emb) for emb in chunk_embeddings]\n",
    "\n",
    "    # 获取最相似的 k 个块的索引\n",
    "    top_indices = np.argsort(similarities)[-k:][::-1]\n",
    "\n",
    "    # 返回最相关的 k 个文本块\n",
    "    return [text_chunks[i] for i in top_indices]"
   ],
   "id": "18ff2db60dee6797",
   "outputs": [],
   "execution_count": 9
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-23T06:08:31.171414Z",
     "start_time": "2025-04-23T06:08:30.777110Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Load the validation data from a JSON file\n",
    "with open('../../data/val.json', encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# Extract the first query from the validation data\n",
    "query = data[0]['question']\n",
    "\n",
    "# Get top 2 relevant chunks\n",
    "top_chunks = semantic_search(query, text_chunks, chunk_embeddings, k=2)\n",
    "\n",
    "# Print the query\n",
    "print(f\"Query: {query}\")\n",
    "\n",
    "# Print the top 2 most relevant text chunks\n",
    "for i, chunk in enumerate(top_chunks):\n",
    "    print(f\"Context {i+1}:\\n{chunk}\\n{'='*40}\")"
   ],
   "id": "544268c8c46c85b9",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: 什么是‘可解释人工智能’，为什么它被认为很重要？\n",
      "Context 1:\n",
      "\n",
      "透明度和可解释性\n",
      "透明度和可解释性对于建⽴对⼈⼯智能系统的信任⾄关重要。\n",
      "========================================\n",
      "Context 2:\n",
      "\n",
      "透明度和可解释性\n",
      "许多⼈⼯智能系统，尤其是深度学习模型，都是“⿊匣⼦”，很难理解它们是如何做出决策的。\n",
      "========================================\n"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 基于检索到的片段生成响应\n",
   "id": "f4a262d79472d5a0"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-23T06:08:54.587567Z",
     "start_time": "2025-04-23T06:08:47.882788Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Define the system prompt for the AI assistant\n",
    "system_prompt = \"你是一个AI助手，严格根据给定的上下文进行回答。如果无法直接从提供的上下文中得出答案，请回复：'我没有足够的信息来回答这个问题。'\"\n",
    "\n",
    "def generate_response(system_prompt, user_message):\n",
    "    \"\"\"\n",
    "    Generates a response from the AI model based on the system prompt and user message.\n",
    "\n",
    "    Args:\n",
    "    system_prompt (str): The system prompt to guide the AI's behavior.\n",
    "    user_message (str): The user's message or query.\n",
    "\n",
    "    Returns:\n",
    "    dict: The response from the AI model.\n",
    "    \"\"\"\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "            model=os.getenv(\"LLM_MODEL_ID\"),\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": system_prompt},\n",
    "                {\"role\": \"user\", \"content\": user_message}\n",
    "            ],\n",
    "            temperature=0.1,\n",
    "            top_p=0.8,\n",
    "            presence_penalty=1.05,\n",
    "            max_tokens=4096,\n",
    "        )\n",
    "    return response.choices[0].message.content\n",
    "\n",
    "# Create the user prompt based on the top chunks\n",
    "user_prompt = \"\\n\".join([f\"上下文内容 {i + 1}:\\n{chunk}\\n=====================================\\n\" for i, chunk in enumerate(top_chunks)])\n",
    "user_prompt = f\"{user_prompt}\\n问题: {query}\"\n",
    "\n",
    "# Generate AI response\n",
    "ai_response = generate_response(system_prompt, user_prompt)\n",
    "print(ai_response)"
   ],
   "id": "a5c7356b92330ba0",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'可解释人工智能'（Explainable AI, XAI）指的是那些能够向用户清晰地解释其决策过程和逻辑的人工智能系统。这种解释可以是关于模型如何处理输入数据、做出特定决策的原因，以及这些决策背后的推理过程。\n",
      "\n",
      "它被认为很重要，主要有以下几个原因：\n",
      "\n",
      "1. **建立信任**：当用户能够理解AI系统是如何工作时，他们更有可能信任这些系统的决策。如上下文内容1所述，透明度和可解释性对于建立对人工智能系统的信任至关重要。\n",
      "\n",
      "2. **安全性**：了解AI的决策过程有助于识别和纠正潜在的错误或偏差，从而提高系统的安全性和可靠性。\n",
      "\n",
      "3. **合规性**：在某些行业（如医疗、金融等），法规可能要求AI系统提供决策的解释，以确保其符合法律和伦理标准。\n",
      "\n",
      "4. **改进和优化**：通过理解AI的决策逻辑，开发者可以更好地优化模型，提高其性能和准确性。\n",
      "\n",
      "5. **用户接受度**：用户更愿意接受那些他们能够理解和控制的AI系统，这有助于推广AI技术的应用。\n",
      "\n",
      "综上所述，可解释人工智能对于提升AI系统的透明度、信任度、安全性和合规性等方面都具有重要意义。\n"
     ]
    }
   ],
   "execution_count": 12
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 评估人工智能响应",
   "id": "1b3acb144a0b8e0"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-23T06:09:44.552492Z",
     "start_time": "2025-04-23T06:09:36.187568Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Define the system prompt for the evaluation system\n",
    "evaluate_system_prompt = \"你是一个智能评估系统，负责评估AI助手的回答。如果AI助手的回答与真实答案非常接近，则评分为1。如果回答错误或与真实答案不符，则评分为0。如果回答部分符合真实答案，则评分为0.5。\"\n",
    "\n",
    "# Create the evaluation prompt by combining the user query, AI response, true response, and evaluation system prompt\n",
    "evaluation_prompt = f\"用户问题: {query}\\nAI回答:\\n{ai_response}\\nTrue Response: {data[0]['ideal_answer']}\\n{evaluate_system_prompt}\"\n",
    "\n",
    "# Generate the evaluation response using the evaluation system prompt and evaluation prompt\n",
    "evaluation_response = generate_response(evaluate_system_prompt, evaluation_prompt)\n",
    "print(evaluation_response)"
   ],
   "id": "42c9389447728cdf",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "根据提供的AI回答和真实答案，AI助手的回答与真实答案非常接近，涵盖了可解释人工智能（XAI）的定义、重要性及其多个方面的应用。具体来说：\n",
      "\n",
      "1. **定义**：AI回答正确地解释了XAI是能够清晰解释决策过程和逻辑的人工智能系统。\n",
      "2. **重要性**：\n",
      "   - **建立信任**：AI回答提到了透明度和可解释性对于建立信任的重要性，与真实答案一致。\n",
      "   - **安全性**：AI回答提到了通过了解决策过程来提高系统的安全性和可靠性，这与真实答案中的问责制有一定关联。\n",
      "   - **合规性**：AI回答提到了法规要求解释决策，确保符合法律和伦理标准，这与真实答案中的公平性有一定关联。\n",
      "   - **改进和优化**：AI回答提到了通过理解决策逻辑来优化模型，这是对XAI重要性的补充说明。\n",
      "   - **用户接受度**：AI回答提到了用户更愿意接受可理解的AI系统，这与建立信任相关。\n",
      "\n",
      "虽然AI回答没有直接提到“问责制”和“公平性”，但其提到的多个方面（如安全性、合规性）与这些概念有间接关联，且整体上涵盖了真实答案的核心要点。\n",
      "\n",
      "因此，AI助手的回答与真实答案非常接近，评分为1。\n"
     ]
    }
   ],
   "execution_count": 13
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
